Make gemm_strided_batched! work with PermutedDimsArrays #664
Conversation
I also have a crude attempt at benchmarking things, on a very slow card. Would be curious whether the gain is similar on faster cards (6x on size 128 here, 4x on size 256).

using CuArrays, NNlib, BenchmarkTools
xx, yy = randn(Float32, 128,128,128), randn(Float32, 128,128,128); cxx = cu(xx); cyy = cu(yy);
CuArrays.allowscalar(false)

# simple batched_cu, not inserting transposes, and including some wrong answers!
for a in 1:3, b in 1:3, c in 1:3
    perm_ = (a,b,c)
    isperm(perm_) || continue
    eager = @belapsed CuArrays.@sync batched_cu(permutedims($cxx, $perm_), permutedims($cyy, $perm_))
    try
        lazy = @belapsed CuArrays.@sync batched_cu(PermutedDimsArray($cxx,$perm_), PermutedDimsArray($cyy,$perm_))
        @info "perm = $perm_ times:" eager lazy eager/lazy
    catch err
        @error "perm = $perm_ gives an error" err
    end
end
# ┌ Info: perm = (1, 2, 3) times:
# │ eager = 0.011564265
# │ lazy = 0.001926885
# └ eager / lazy = 6.001533563238077
# ┌ Info: perm = (1, 3, 2) times:
# │ eager = 0.011582102
# │ lazy = 0.001946393
# └ eager / lazy = 5.9505464723722294
# ** On entry to SGEMM parameter number 8 had an illegal value
# ┌ Error: perm = (2, 1, 3) gives an error
# │ err = CUBLASError: an invalid value was used as an argument (code 7, CUBLAS_STATUS_INVALID_VALUE)
# └ @ Main REPL[97]:9
# ...
# sanity checks?
@btime CuArrays.@sync batched_cu($cxx, $cyy); # 1.948 ms (17 allocations: 512 bytes)
@btime CuArrays.@sync batched_mul($cxx, $cyy); # 1.922 ms (17 allocations: 512 bytes)
@btime CuArrays.@sync permutedims($cxx, (1,2,3)); # 4.997 ms (89 allocations: 2.86 KiB)
@btime CuArrays.@sync permutedims($cxx, (3,2,1)); # 8.670 ms (89 allocations: 2.86 KiB)
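Since the eager batched_cu path above is noted as including some wrong answers, a companion correctness check is easy to bolt on. This is a sketch of my own, reusing the same arrays and the same batched_cu helper; the choice of perms and the CPU reference are my additions, not part of the PR.

# correctness companion to the timing loop above: compare the lazy GPU result
# against a plain CPU loop, for the two perms that keep a stride-1 first dimension
function cpu_batched(x, y)
    out = similar(x, size(x,1), size(y,2), size(x,3))
    for k in axes(x,3)
        out[:,:,k] = x[:,:,k] * y[:,:,k]
    end
    out
end
for perm_ in [(1,2,3), (1,3,2)]
    gpu = collect(batched_cu(PermutedDimsArray(cxx, perm_), PermutedDimsArray(cyy, perm_)))
    cpu = cpu_batched(PermutedDimsArray(xx, perm_), PermutedDimsArray(yy, perm_))
    @info "perm = $perm_" maximum(abs, gpu .- cpu)
end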
cc @haampie, IIUC you were looking at reworking BLAS wrappers for use with Base's strided abstractions too.
Two approaches to how point 3 should work are FluxML/NNlib.jl#187 (using ArrayLayouts) and FluxML/NNlib.jl#191 (directly checking strides). Point 4 is now also addressed. Allowing …
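To make "directly checking strides" concrete, this is roughly the kind of predicate that approach rests on. A sketch only; is_strided_batched is a made-up name, not NNlib's actual API.

# a batch of matrices can go straight to gemm_strided_batched! whenever each
# slice keeps a contiguous first dimension (stride 1); other permutations need
# the permutation swapped and the 'T' flag, or an eager permutedims
is_strided_batched(A::AbstractArray{<:Any,3}) = strides(A)[1] == 1

A = randn(Float32, 4, 5, 6)
is_strided_batched(A)                                # true
is_strided_batched(PermutedDimsArray(A, (1, 3, 2)))  # true: stride-1 columns survive
is_strided_batched(PermutedDimsArray(A, (2, 1, 3)))  # false: would need the 'T' variant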
@@ -937,15 +937,16 @@ for (fname, elty) in
    function gemm_strided_batched!(transA::Char,
                                   transB::Char,
                                   alpha::($elty),
                                   A::CuArray{$elty, 3},
I think when I tried to wrap this I intended to use it as a low-level API. Now that both CUDAnative and CuArrays have changed a lot, maybe we need a more bare wrapper (taking a pointer type like CuPtr) that directly wraps the CUBLAS API? Then it'd be more elegant to have a higher-level wrapper for different Julia array types.
Can NNlib.batched_mul be this higher-level wrapper? FluxML/NNlib.jl#191 makes it more flexible, and able to dispatch according to the underlying data.

And what can't you do with this wrapper (which works on any AbstractArray for which this pointer exists) which you could do with a different one?
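As a CPU-side analogue of "any AbstractArray for which this pointer exists", just to illustrate the recursive unwrapping idea: recursive_pointer is an illustrative name of mine, not the PR's definition, which does the equivalent for CuPtr{T}.

# peel off wrappers via parent() until something with a raw pointer appears
recursive_pointer(A::DenseArray) = pointer(A)
recursive_pointer(A::AbstractArray) = recursive_pointer(parent(A))

A = randn(Float32, 3, 4, 5)
recursive_pointer(PermutedDimsArray(A, (1, 3, 2))) == pointer(A)  # true: same memory, new strides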
Now copied to JuliaGPU/CUDA.jl#539
CUBLAS's gemm_strided_batched! accepts strides apart from the ordered ones of a CuArray. This is potentially useful as it saves on performing permutedims before operations, and the obvious way to expose this is to let it accept PermutedDimsArrays. This PR is an attempt to make that happen... while trying to avoid problems with wrappers & dispatch:

1. I presume that by the time anyone is calling CuArrays.CUBLAS.gemm_strided_batched!, they are well aware that the operation is to be done by CUBLAS, and not expecting dispatch to redirect to the CPU if it's an Array. So I changed this to accept any AbstractArray.

2. It needs a pointer CuPtr{T}, which wasn't defined for PermutedDimsArray; I've added a generic definition which unwraps this recursively. (Not sure where such a defn should actually live.) Then it can run.

3. Ideally NNlib.batched_mul will be a friendly function which dispatches every kind of array to the best routine. This is a bit of a mess right now; it needs Improvements to batched_mul, including PermutedDimsArray FluxML/NNlib.jl#187 (which similarly tries to allow for PermutedDimsArray on the CPU, but maybe not all the same ones), so don't look too hard just yet.

4. It also needs similar() to unwrap and see the CuArray: similar(PermutedDimsArray(::CuArray)) isa Array #658, fixed by Use parent for similar(::PermutedDimsArray) JuliaLang/julia#35304 + Compat.jl I suppose, but for now copied in here to try things out.

This still isn't quite the most general thing, as there are matrix * 3-tensor contractions which should perhaps be done by this routine, e.g. the table here with something like strides(A)==(1,N,1).

Here's a test of steps 1 & 2:

which gives
The failures for perm = (2, 1, 3) etc. are fine; we should adjust the permutation & use the 'T' variant, and my messy NNlib code tries to do this (and needs to do it on CPU too).

The wrong answer for perm = (3, 2, 1) is more curious; note that it's not completely wrong:

Perhaps this is just some mistake with strides, or perhaps this case is in fact not supported, I'm not sure.
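For what it's worth, a small CPU check of the point about perm = (2, 1, 3) above (my own illustration, not code from this PR): each slice of that permutation is just the transpose of the parent's slice, which is why swapping the permutation and using the 'T' variant is the right fix.

A = randn(Float32, 3, 4, 2)
P = PermutedDimsArray(A, (2, 1, 3))
all(P[:, :, k] == transpose(A[:, :, k]) for k in 1:size(A, 3))  # true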
c.c. @Roger-luo who has also messed with these things.